# syntax=docker/dockerfile:1
# JupyterLab image on top of the PyTorch CUDA runtime image.
#
# The server runs as a non-root user whose name/UID/GID are build args,
# presumably so they can match the host user and files written to a
# bind-mounted /wd keep host ownership. Example:
#
#   docker build \
#     --build-arg USER="$(id -un)" --build-arg GROUP="$(id -gn)" \
#     --build-arg UID="$(id -u)"   --build-arg GID="$(id -g)" \
#     -t jupyter .
#   docker run --init -p 9000:9000 -v "$PWD":/wd jupyter
#
# (`--init` is optional; it adds a tiny init as PID 1 to reap any orphaned
# kernel subprocesses.)
FROM pytorch/pytorch:2.3.0-cuda12.1-cudnn8-runtime

# Install Python packages as root, *before* dropping privileges: the base
# image's Python environment is root-owned, so a non-root `pip install` would
# fall back to a --user install under ~/.local, whose bin/ (containing the
# `jupyter` launcher) is not on the image PATH.
# This step is also placed before the per-user ARGs so its (large) layer is
# reused from cache regardless of which USER/UID/GID the image is built for.
# NOTE(review): versions are unpinned (hadolint DL3013); pin them if
# reproducible builds matter.
RUN pip install --no-cache-dir \
        einops \
        jupyterlab \
        matplotlib \
        scikit-learn

# Identity of the runtime user. Defaults only make a bare `docker build .`
# work; override them to match the host user.
ARG USER=jupyter
ARG GROUP=jupyter
ARG UID=1000
ARG GID=1000

# Create the group, the user (with a home dir and bash login shell) and the
# user-owned working directory in a single layer.
# `-l` (--no-log-init) stops a large UID from creating a huge sparse
# /var/log/lastlog inside the image (hadolint DL3046).
RUN groupadd -g "${GID}" "${GROUP}" \
 && useradd -l -u "${UID}" -g "${GROUP}" -s /bin/bash -m "${USER}" \
 && mkdir /wd \
 && chown "${USER}:${GROUP}" /wd
WORKDIR /wd

USER ${UID}:${GID}

# Presumably so JupyterLab's terminal launches bash rather than a fallback sh.
ENV SHELL=/bin/bash

# Documentation only; publish with `-p 9000:9000` at `docker run`.
EXPOSE 9000

# Exec form: jupyter runs as PID 1 and receives SIGTERM from `docker stop`
# directly. No login shell is needed because `jupyter` is now on the image PATH.
CMD ["jupyter", "lab", "--ip=0.0.0.0", "--port=9000"]