"""Per-product I/O helpers. Read one product's relevant fields, write descriptions back."""
import openpyxl, json, os, sys, io, argparse, tempfile
# Re-wrap stdout so everything this script prints is encoded as UTF-8,
# independent of the console's default encoding.
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')

# All data files are resolved relative to the directory containing this script.
BASE = os.path.dirname(os.path.abspath(__file__))
DST = os.path.join(BASE, 'merged_master_v2_updated.xlsx')  # workbook that is read and rewritten in place
TARGETS = os.path.join(BASE, '_targets.json')  # JSON list of target dicts, each with a 'leader' key
PROGRESS = os.path.join(BASE, '_progress.json')  # JSON state: 'done_leaders', 'done_cikks', 'next_index'

# Cache of the first-row (header) cell values of 'Sheet1'; filled lazily by load_headers().
HEADER_LABELS = None  # None until the first load_headers() call

def load_headers():
    """Return the header-row cell values of 'Sheet1' in the DST workbook.

    The workbook is opened only on the first call; the resulting list is
    cached in the module-level HEADER_LABELS and returned as-is afterwards.
    Entries are raw cell values, so a blank header cell shows up as None.

    Returns:
        list: one entry per cell of the first row, in column order.
    """
    global HEADER_LABELS
    if HEADER_LABELS is None:
        wb = openpyxl.load_workbook(DST, read_only=True)
        try:
            ws = wb['Sheet1']
            first_row = next(ws.iter_rows(min_row=1, max_row=1))
            HEADER_LABELS = [cell.value for cell in first_row]
        finally:
            # A read-only workbook keeps its file handle open until close(),
            # so release it even when the sheet lookup or the row read fails.
            wb.close()
    return HEADER_LABELS

# Headers for stock/price/category-export/redirect style fields that are not
# useful when writing a description. A frozenset keeps the membership test cheap.
_SKIPPED_HEADERS = frozenset((
    'Státusz', 'Nettó Ár', 'Bruttó Ár', 'Akciós Nettó Ár', 'Akciós Bruttó Ár',
    'Akció Kezdet', 'Akció Lejárat', 'Raktárkészlet',
    'További Raktárkészlet: Külső raktár', 'Vásárolható, ha nincs Raktáron',
    'Változatokhoz Raktárkészlet', 'Alacsony készlet', 'Tömeg',
    'Paraméter: Garancia||text', 'Paraméter: Arukereso.hu Export Kategória||text',
    'Paraméter: Méretismertető||linkblank',
    'Paraméter: Árukereső.hu Szállítási költség||text',
    'Paraméter: Árukereső.hu Szállítási idő||text',
    'Paraméter: DeliveryTime||text',
    'Kép link', 'Kép kapcsolat',
))


def read_product(row_idx):
    """Read one worksheet row and return its non-empty, description-relevant fields.

    Args:
        row_idx: 1-based row number in 'Sheet1' of the DST workbook.

    Returns:
        dict mapping header label -> stripped string value for every cell that
        is neither None nor a whitespace-only string and whose header is not in
        the skip list; columns without a header label are keyed 'col_<index>'
        (0-based). Returns None when the sheet yields no row at ``row_idx``.
        Existing AI description columns are not filtered out, so they are
        returned too when filled.
    """
    headers = load_headers()
    wb = openpyxl.load_workbook(DST, read_only=True)
    try:
        ws = wb['Sheet1']
        row_iter = ws.iter_rows(min_row=row_idx, max_row=row_idx, values_only=True)
        target_row = next(row_iter, None)
    finally:
        # Read-only workbooks hold the file open until closed explicitly.
        wb.close()
    if target_row is None:
        return None

    fields = {}
    for i, val in enumerate(target_row):
        if val is None or (isinstance(val, str) and not val.strip()):
            continue
        # A data row can be wider than the header row; fall back to a
        # positional name instead of raising IndexError.
        label = headers[i] if i < len(headers) else None
        h = label or f'col_{i}'
        if h in _SKIPPED_HEADERS:
            continue
        fields[h] = str(val).strip()
    return fields

def write_descriptions(rows, short_html, long_html):
    """Write the short and long description into every given row of 'Sheet1'.

    The same two values are written to each row: ``short_html`` into column 104
    and ``long_html`` into column 105 (1-based). The workbook is saved to a
    sibling '.tmp' file first and then moved over DST with os.replace, so DST
    is only ever replaced by a completely written file.

    Args:
        rows: iterable of 1-based row numbers to update.
        short_html: value for the 'Rövid Leírás (AI)' column.
        long_html: value for the 'Hosszú Leírás (AI)' column.

    Raises:
        Any error from loading, saving or replacing the workbook is propagated;
        the temporary file is removed first so no partial file is left behind.
    """
    rows = list(rows)
    if not rows:
        # Nothing to change: skip the full load/save round-trip of the workbook.
        return

    short_col = 104  # 'Rövid Leírás (AI)' is index 103 -> col 104
    long_col = 105   # 'Hosszú Leírás (AI)' is index 104 -> col 105
    tmp = DST + '.tmp'
    try:
        wb = openpyxl.load_workbook(DST)  # full mode, not read_only
        try:
            ws = wb['Sheet1']
            for r in rows:
                ws.cell(row=r, column=short_col, value=short_html)
                ws.cell(row=r, column=long_col, value=long_html)
            wb.save(tmp)
        finally:
            wb.close()
        # Atomic save: the finished temp file replaces the original in one step.
        os.replace(tmp, DST)
    except BaseException:
        # Best-effort cleanup of a stale or partial temp file, then re-raise the
        # original error unchanged.
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise

def update_progress(leader, follower_cikks):
    """Record a finished leader and its cikkszáms in the PROGRESS JSON file.

    Loads the progress dict, appends ``leader`` to 'done_leaders' and every
    entry of ``follower_cikks`` to 'done_cikks' (each only if not already
    present), increments 'next_index', and writes the result back through a
    temporary file followed by os.replace.

    Args:
        leader: identifier of the target that was just processed.
        follower_cikks: iterable of cikkszám values covered by that target.
    """
    with open(PROGRESS, encoding='utf-8') as f:
        prog = json.load(f)

    # Tolerate a progress file without 'done_leaders', the same way
    # 'done_cikks' and 'next_index' are already allowed to be absent.
    done_leaders = prog.setdefault('done_leaders', [])
    if leader not in done_leaders:
        done_leaders.append(leader)

    # NOTE(review): the counter advances on every call, even when the leader
    # was already recorded — confirm that consumers of 'next_index' expect this.
    prog['next_index'] = prog.get('next_index', 0) + 1

    done_cikks = prog.setdefault('done_cikks', [])
    seen = set(done_cikks)  # constant-time duplicate check; list keeps the order
    for c in follower_cikks:
        if c not in seen:
            seen.add(c)
            done_cikks.append(c)

    # Write to a temp file first so PROGRESS is never left half-written.
    tmp = PROGRESS + '.tmp'
    with open(tmp, 'w', encoding='utf-8') as f:
        json.dump(prog, f, ensure_ascii=False, indent=1)
    os.replace(tmp, PROGRESS)

def get_next_target():
    """Return the first pending target, or None when every leader is done.

    Targets come from the TARGETS JSON file in file order; a target is pending
    when its 'leader' value is not listed under 'done_leaders' in the PROGRESS
    JSON file (a missing 'done_leaders' key counts as nothing done yet).
    """
    with open(TARGETS, encoding='utf-8') as targets_file:
        all_targets = json.load(targets_file)
    with open(PROGRESS, encoding='utf-8') as progress_file:
        progress = json.load(progress_file)

    finished = set(progress.get('done_leaders', []))
    pending = (target for target in all_targets if target['leader'] not in finished)
    return next(pending, None)

def get_next_targets(n):
    """Return up to ``n`` pending targets, in TARGETS file order.

    A target is pending when its 'leader' value is not listed under
    'done_leaders' in the PROGRESS JSON file.

    Args:
        n: maximum number of targets to return; zero or a negative value
            yields an empty list.

    Returns:
        list of target dicts, possibly shorter than ``n`` (or empty).
    """
    if n <= 0:
        # The limit is checked after appending below, so without this guard a
        # request for zero targets would still hand back one.
        return []
    with open(TARGETS, encoding='utf-8') as f:
        targets = json.load(f)
    with open(PROGRESS, encoding='utf-8') as f:
        prog = json.load(f)
    done = set(prog.get('done_leaders', []))
    out = []
    for t in targets:
        if t['leader'] in done:
            continue
        out.append(t)
        if len(out) >= n:
            break
    return out

if __name__ == '__main__':
    # Command-line entry point. Sub-commands:
    #   next      print the next pending target as JSON, or 'ALL DONE'
    #   next_n    print up to N pending targets as a JSON list
    #   read      print the relevant fields of one worksheet row as JSON
    #   write     store descriptions for the given rows and record progress
    #   progress  print the raw progress file
    p = argparse.ArgumentParser()
    sub = p.add_subparsers(dest='cmd', required=True)
    sub.add_parser('next')
    sub_n = sub.add_parser('next_n')
    sub_n.add_argument('n', type=int)
    sub_r = sub.add_parser('read')
    sub_r.add_argument('row', type=int)
    sub_w = sub.add_parser('write')
    sub_w.add_argument('--rows', required=True, help='comma-separated row indexes')
    sub_w.add_argument('--short-file', required=True, help='path to file containing short description HTML')
    sub_w.add_argument('--long-file', required=True, help='path to file containing long description HTML')
    sub_w.add_argument('--leader', required=True)
    sub_w.add_argument('--cikks', required=True, help='comma-separated cikkszáms')
    sub.add_parser('progress')
    args = p.parse_args()

    if args.cmd == 'next':
        t = get_next_target()
        print(json.dumps(t, ensure_ascii=False, indent=1) if t else 'ALL DONE')
    elif args.cmd == 'next_n':
        ts = get_next_targets(args.n)
        print(json.dumps(ts, ensure_ascii=False, indent=1))
    elif args.cmd == 'read':
        fields = read_product(args.row)
        print(json.dumps(fields, ensure_ascii=False, indent=1))
    elif args.cmd == 'write':
        # Ignore blank items (e.g. a trailing comma) and surrounding spaces so
        # they neither crash int() nor end up recorded as bogus cikkszáms.
        rows = [int(tok) for tok in args.rows.split(',') if tok.strip()]
        cikks = [tok.strip() for tok in args.cikks.split(',') if tok.strip()]
        if not rows:
            # Refuse to mark a leader as done when no row would be written.
            p.error('--rows must contain at least one row index')
        with open(args.short_file, encoding='utf-8') as fh:
            short = fh.read().strip()
        with open(args.long_file, encoding='utf-8') as fh:
            long_ = fh.read().strip()
        # length sanity check (warn only; the write still happens)
        # NOTE(review): both descriptions are checked against the same
        # 700-1100 range — confirm that is intended for the long one.
        sl = len(short)
        ll = len(long_)
        print(f'short_len={sl}, long_len={ll}')
        if sl < 700 or sl > 1100:
            print(f'WARNING: short length {sl} outside 700-1100 range')
        if ll < 700 or ll > 1100:
            print(f'WARNING: long length {ll} outside 700-1100 range')
        write_descriptions(rows, short, long_)
        update_progress(args.leader, cikks)
        print(f'OK: wrote rows {rows}, leader={args.leader}, cikks={cikks}')
    elif args.cmd == 'progress':
        with open(PROGRESS, encoding='utf-8') as fh:
            print(fh.read())
